{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install a global filter that ignores Python warnings raised from here on.\n",
    "import warnings\n",
    "\n",
    "warnings.simplefilter(action=\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import datasets, layers, models, callbacks\n",
    "from tensorflow.keras.datasets import mnist\n",
    "\n",
    "import os\n",
    "\n",
    "# Absolute path handed to load_data() as the location of the MNIST archive.\n",
    "file_path = os.path.abspath('./mnist.npz')\n",
    "(train_x, train_y), (test_x, test_y) = datasets.mnist.load_data(path=file_path)\n",
    "\n",
    "# Keep only the first 1000 examples of each split; flatten every image to a\n",
    "# 784-element row (28 * 28) and divide by 255.0\n",
    "# (presumably to scale 8-bit pixel values into [0, 1]).\n",
    "train_x = train_x[:1000].reshape(-1, 28 * 28) / 255.0\n",
    "train_y = train_y[:1000]\n",
    "test_x = test_x[:1000].reshape(-1, 28 * 28) / 255.0\n",
    "test_y = test_y[:1000]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 搭建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_model():\n",
    "    \"\"\"Build and compile the classifier used throughout this notebook.\n",
    "\n",
    "    Returns:\n",
    "        A compiled Sequential model: Dense(512, relu) over 784 inputs,\n",
    "        Dropout(0.2), then Dense(10, softmax). It is compiled with the\n",
    "        'adam' optimizer, sparse categorical cross-entropy loss and an\n",
    "        accuracy metric.\n",
    "    \"\"\"\n",
    "    network = models.Sequential()\n",
    "    network.add(layers.Dense(512, activation='relu', input_shape=(784,)))\n",
    "    network.add(layers.Dropout(0.2))\n",
    "    network.add(layers.Dense(10, activation='softmax'))\n",
    "\n",
    "    network.compile(\n",
    "        optimizer='adam',\n",
    "        loss='sparse_categorical_crossentropy',\n",
    "        metrics=['accuracy'],\n",
    "    )\n",
    "    return network\n",
    "\n",
    "\n",
    "def evaluate(target_model):\n",
    "    \"\"\"Evaluate a model on the notebook-level test_x / test_y and print its accuracy.\n",
    "\n",
    "    Args:\n",
    "        target_model: Compiled model whose evaluate() yields a\n",
    "            (loss, accuracy) pair.\n",
    "    \"\"\"\n",
    "    _loss, accuracy = target_model.evaluate(test_x, test_y)\n",
    "    accuracy_pct = 100 * accuracy\n",
    "    print(\"Restore model, accuracy: {:5.2f}%\".format(accuracy_pct))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自动保存 checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0713 00:06:20.997914 140735530943360 callbacks.py:859] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.\n",
      "W0713 00:06:21.173481 140735530943360 deprecation.py:323] From /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 00010: saving model to training_2/cp-0010.ckpt\n",
      "\n",
      "Epoch 00020: saving model to training_2/cp-0020.ckpt\n",
      "\n",
      "Epoch 00030: saving model to training_2/cp-0030.ckpt\n",
      "\n",
      "Epoch 00040: saving model to training_2/cp-0040.ckpt\n",
      "\n",
      "Epoch 00050: saving model to training_2/cp-0050.ckpt\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x13de39e80>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Checkpoint filename template; the placeholder uses str.format syntax and\n",
    "# is filled in with the epoch number.\n",
    "checkpoint_path = \"training_2/cp-{epoch:04d}.ckpt\"\n",
    "checkpoint_dir = os.path.dirname(checkpoint_path)\n",
    "\n",
    "# period=10: save the weights once every 10 epochs.\n",
    "cp_callback = callbacks.ModelCheckpoint(\n",
    "    filepath=checkpoint_path,\n",
    "    save_weights_only=True,\n",
    "    period=10,\n",
    "    verbose=1,\n",
    ")\n",
    "\n",
    "model = create_model()\n",
    "# Also store the freshly initialised weights under the epoch-0 filename.\n",
    "model.save_weights(checkpoint_path.format(epoch=0))\n",
    "model.fit(\n",
    "    x=train_x,\n",
    "    y=train_y,\n",
    "    epochs=50,\n",
    "    validation_data=(test_x, test_y),\n",
    "    callbacks=[cp_callback],\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000/1000 [==============================] - 0s 94us/sample - loss: 0.5232 - accuracy: 0.8720\n",
      "Restore model, accuracy: 87.20%\n"
     ]
    }
   ],
   "source": [
    "# Locate the most recent checkpoint in checkpoint_dir,\n",
    "# e.g. 'training_2/cp-0050.ckpt'.\n",
    "latest = tf.train.latest_checkpoint(checkpoint_dir)\n",
    "# latest_checkpoint() returns None when nothing was saved; fail with a clear\n",
    "# message here instead of passing None on to load_weights().\n",
    "assert latest is not None, 'no checkpoint found in {!r}'.format(checkpoint_dir)\n",
    "\n",
    "# Rebuild an untrained model and restore the saved weights into it.\n",
    "model = create_model()\n",
    "model.load_weights(latest)\n",
    "evaluate(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 手动保存权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000/1000 [==============================] - 0s 90us/sample - loss: 0.5232 - accuracy: 0.8720\n",
      "Restore model, accuracy: 87.20%\n"
     ]
    }
   ],
   "source": [
    "# Save the weights by hand, then restore them into a fresh model.\n",
    "# NOTE(review): 'mannul' looks like a misspelling of 'manual'; the path is kept\n",
    "# unchanged so checkpoints written by earlier runs still resolve.\n",
    "manual_ckpt_path = './checkpoints/mannul_checkpoint'\n",
    "model.save_weights(manual_ckpt_path)\n",
    "\n",
    "model = create_model()\n",
    "model.load_weights(manual_ckpt_path)\n",
    "evaluate(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存整个模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### HDF5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the current model to a single .h5 file.\n",
    "model.save(filepath='my_model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0713 00:06:28.440529 140735530943360 hdf5_format.py:192] Error in loading the saved optimizer state. As a result, your model is starting with a freshly initialized optimizer.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000/1000 [==============================] - 0s 91us/sample - loss: 0.5232 - accuracy: 0.8720\n",
      "Restore model, accuracy: 87.20%\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the model from the .h5 file and report its accuracy on the test subset.\n",
    "new_model = models.load_model(filepath='my_model.h5')\n",
    "evaluate(new_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### saved_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0713 00:06:29.094913 140735530943360 deprecation.py:323] From /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.\n",
      "W0713 00:06:29.096133 140735530943360 export_utils.py:182] Export includes no default signature!\n",
      "W0713 00:06:29.277875 140735530943360 util.py:244] Unresolved object in checkpoint: (root).optimizer.iter\n",
      "W0713 00:06:29.278687 140735530943360 util.py:244] Unresolved object in checkpoint: (root).optimizer.beta_1\n",
      "W0713 00:06:29.279356 140735530943360 util.py:244] Unresolved object in checkpoint: (root).optimizer.beta_2\n",
      "W0713 00:06:29.280627 140735530943360 util.py:244] Unresolved object in checkpoint: (root).optimizer.decay\n",
      "W0713 00:06:29.281178 140735530943360 util.py:244] Unresolved object in checkpoint: (root).optimizer.learning_rate\n",
      "W0713 00:06:29.281950 140735530943360 util.py:252] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/alpha/guide/checkpoints#loading_mechanics for details.\n",
      "W0713 00:06:29.402854 140735530943360 export_utils.py:182] Export includes no default signature!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(1000, 10)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# Export into a directory named after the current Unix time in whole seconds,\n",
    "# so each run gets its own path (at one-second resolution).\n",
    "saved_model_path = \"./saved_models/{}\".format(int(time.time()))\n",
    "tf.keras.experimental.export_saved_model(model, saved_model_path)\n",
    "\n",
    "# Reload the export and run a prediction with the *reloaded* model; calling\n",
    "# predict() on the original `model` would not exercise the round trip.\n",
    "new_model = tf.keras.experimental.load_from_saved_model(saved_model_path)\n",
    "new_model.predict(test_x).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000/1000 [==============================] - 0s 95us/sample - loss: 0.5232 - accuracy: 0.8720\n",
      "Restore model, accuracy: 87.20%\n"
     ]
    }
   ],
   "source": [
    "# Compile the reloaded model, reusing the optimizer object of `model`,\n",
    "# then report its accuracy on the test subset.\n",
    "new_model.compile(\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    metrics=['accuracy'],\n",
    "    optimizer=model.optimizer,\n",
    ")\n",
    "evaluate(new_model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
